##############
import DonaldDuckDataset
from foolbox4attack import attackMethods
from DonaldDuckDRR import DRR
from DonaldDuckEn_De_R import En_De_R
from DonaldDuckDe_R import De_R
import tensorflow as tf
import DonaldDuckConv
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from Get_Detector import detectModels

# Restrict CUDA to device 0, with devices enumerated in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
physical_devices = tf.config.list_physical_devices('GPU')
# Best effort: ask TensorFlow to enable memory growth on the first visible GPU.
# The script must keep working on machines without a GPU, so that case is
# skipped explicitly, and a refused setting is reported instead of raised.
if physical_devices:
    try:
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
        if not tf.config.experimental.get_memory_growth(physical_devices[0]):
            print('Warning: GPU memory growth was requested but is not enabled')
    except Exception as err:
        # Deliberately broad (this is only an optimisation), but unlike a bare
        # ``except`` it lets KeyboardInterrupt / SystemExit propagate.
        print('Warning: could not enable GPU memory growth: ' + str(err))
# Fixed global TensorFlow seed.
tf.random.set_seed(123)

# Dataset under evaluation. To switch, comment out the active line and
# uncomment one of the alternatives; the victim model built further down is
# chosen from ``dataset.name``.
dataset = DonaldDuckDataset.CIFAR10(standardization=False)
# dataset = DonaldDuckDataset.Fashion(standardization=False)
# dataset = DonaldDuckDataset.MNIST(standardization=False)

def importAdv(adv_img_path, input_shape=None):
    """Load flattened images from a CSV file and restore their image shape.

    The first CSV column is skipped (presumably an index column -- confirm
    against the code that writes these files); the following
    ``prod(input_shape)`` columns are read as one flattened image per row.

    Args:
        adv_img_path: Path of the CSV file to read.
        input_shape: Shape of a single image. Defaults to
            ``dataset.input_shape``.

    Returns:
        A numpy array of shape ``(n_rows,) + input_shape``.
    """
    if input_shape is None:
        input_shape = dataset.input_shape
    input_shape = tuple(input_shape)
    values_per_image = int(np.prod(input_shape))

    flat = np.array(pd.read_csv(adv_img_path, usecols=range(1, values_per_image + 1)))
    # Reshape straight to ``input_shape``. The previous version swapped the
    # first two dimensions, which only worked for square images and disagreed
    # with plot_img, which reshapes each image to ``dataset.input_shape``.
    return flat.reshape((flat.shape[0],) + input_shape)


# Build the victim ("target") classifier for the selected dataset and work out
# which weight file to restore: the DonaldDuckVGG16 wrapper for 'cifar10', the
# DonaldDuckCNN wrapper for every other dataset.
# NOTE: build_dir is left at its default for the victim model (the detectors
# created further down pass build_dir=False).
if dataset.name == 'cifar10':
    tar_model = DonaldDuckConv.DonaldDuckVGG16(dataset)
    tar_model.setModel()
    victim_weights = r'savedModels//' + '_'.join(['VGG16', dataset.name]) + '.h5'
else:
    conv_layers_num = 5
    init_filters = 32
    tar_model = DonaldDuckConv.DonaldDuckCNN(dataset)
    tar_model.setModel(
        conv_layers_num=conv_layers_num,
        filters=init_filters,
        kernel_size=(3, 3),
    )
    # The CNN weight file name encodes the layer count and the filter count.
    victim_weights = r'savedModels//' + '_'.join(
        ['CNN', dataset.name, str(conv_layers_num), str(init_filters)]
    ) + '.h5'

tar_model.load_model(weights_path=victim_weights)

# Detector instances, in the iteration order of ``detectModels``.
d_model_list = []
# Display names used for plot titles, tick labels and legends. They are looked
# up by position, so they must line up one-to-one with ``d_model_list``.
d_model_name = ['DRR', 'HVR-P', 'HVR-L', 'HLR-P', 'HLR-L']

# Instantiate every registered detector around the victim model and restore
# its saved encoder / decoder / discriminator weights.
for dms in detectModels:
    spec = detectModels[dms]
    # NOTE: the name ``dgan`` must stay -- plot_all_image reads it as a
    # module-level variable after this loop has finished.
    dgan = spec['model'](
        dataset,
        batch_size=128,
        epochs=50,
        kernel_size=(3, 3),
        build_dir=False,
    )
    dgan.setModel(tar_model=tar_model, skip_flag=False)

    # Per-dataset entry naming where this detector's weights are stored.
    saved = spec[dataset.name]
    weight_date = saved['dir']
    weight_idx = str(saved['idx'])
    model_name = saved['name']

    # All three files share the directory and the "<name><idx>.h5" ending.
    weight_prefix = 'savedModels//' + weight_date + '//weight_'
    weight_suffix = model_name + weight_idx + '.h5'
    dgan.loadWeights(
        encoder_weight_path=weight_prefix + 'encoder_' + weight_suffix,
        decoder_weight_path=weight_prefix + 'decoder_' + weight_suffix,
        disI_weight_path=weight_prefix + 'dis_' + weight_suffix,
    )
    d_model_list.append(dgan)

def plot_img(data, title, first_flag=False, num_img=10, fontsize=26):
    """Save one figure per sampled image: the input next to each reconstruction.

    ``num_img`` images are drawn at random (with replacement) from ``data``.
    Each one is reconstructed individually by every detector in
    ``d_model_list`` and written to
    ``<tar_model.saveImgPath>/<title with newlines as '_'>_<i>.pdf``.

    Args:
        data: Array of images; each image must be reshapeable to
            ``dataset.input_shape``.
        title: Text used as the y-label of the first panel and, with newlines
            replaced by underscores, as the file name stem.
        first_flag: When true, each reconstruction panel is titled with the
            matching entry of ``d_model_name``.
        num_img: Number of images to sample and figures to save.
        fontsize: Font size for the y-label and the panel titles.
    """
    file_stem = title.replace('\n', '_')
    print(file_stem)

    picks = np.random.randint(0, data.shape[0], num_img)
    samples = data[picks]
    n_panels = len(d_model_list) + 1

    def show(image):
        # Three-channel images are drawn as-is, anything else as grayscale
        # using the first channel only.
        if image.shape[2] == 3:
            plt.imshow(image)
        else:
            plt.imshow(image[:, :, 0], cmap='gray')

    for fig_no in range(num_img):
        # The detectors take a batch, so wrap the single image in one.
        batch = samples[fig_no].reshape((1,) + dataset.input_shape)

        plt.figure(figsize=(12, 2))
        plt.subplot(1, n_panels, 1)
        plt.xticks([], [])
        plt.yticks([], [])
        plt.ylabel(title, fontsize=fontsize)
        show(batch[0])

        for panel, detector in enumerate(d_model_list):
            rebuilt = detector.reconstruct(batch)[0]
            plt.subplot(1, n_panels, panel + 2)
            plt.xticks([], [])
            plt.yticks([], [])
            if first_flag:
                plt.title(d_model_name[panel], fontsize=fontsize)
            show(rebuilt)

        plt.tight_layout()
        plt.savefig(tar_model.saveImgPath + '/' + file_stem + '_' + str(fig_no) + '.pdf')
        #plt.show()
        plt.clf()
        plt.close()

def plot_img2(data, title, first_flag=False, num_img=10, fontsize=26):
    """Batched variant of plot_img: reconstruct 300 samples, plot the first few.

    300 images are drawn at random (with replacement) from ``data`` and passed
    to every detector in ``d_model_list`` as a single batch. The first
    ``num_img`` of them are then saved, one figure each, to
    ``<tar_model.saveImgPath>/<title with newlines as '_'>_<i>.pdf``.

    Args:
        data: Array of images, indexable as ``data[index_array]``.
        title: Text used as the y-label of the first panel and, with newlines
            replaced by underscores, as the file name stem.
        first_flag: When true, each reconstruction panel is titled with the
            matching entry of ``d_model_name``.
        num_img: Number of figures to save; must not exceed the 300 sampled
            images.
        fontsize: Font size for the y-label and the panel titles.
    """
    file_stem = title.replace('\n', '_')
    print(file_stem)

    chosen = np.random.randint(0, data.shape[0], 300)
    imgs = data[chosen]
    # One reconstructed batch per detector, in d_model_list order.
    rebuilt_batches = [detector.reconstruct(imgs) for detector in d_model_list]
    n_panels = len(d_model_list) + 1

    def show(image):
        # Three-channel images are drawn as-is, anything else as grayscale
        # using the first channel only.
        if image.shape[2] == 3:
            plt.imshow(image)
        else:
            plt.imshow(image[:, :, 0], cmap='gray')

    for fig_no in range(num_img):
        source = imgs[fig_no]

        plt.figure(figsize=(12, 2))
        plt.subplot(1, n_panels, 1)
        plt.xticks([], [])
        plt.yticks([], [])
        plt.ylabel(title, fontsize=fontsize)
        show(source)

        for panel, rebuilt_batch in enumerate(rebuilt_batches):
            rebuilt = rebuilt_batch[fig_no]
            plt.subplot(1, n_panels, panel + 2)
            plt.xticks([], [])
            plt.yticks([], [])
            if first_flag:
                plt.title(d_model_name[panel], fontsize=fontsize)
            show(rebuilt)

        plt.tight_layout()
        plt.savefig(tar_model.saveImgPath + '/' + file_stem + '_' + str(fig_no) + '.pdf')
        #plt.show()
        plt.clf()
        plt.close()

def plot_all_image():
    """Save reconstruction figures for benign test images and for every attack.

    Benign images come from ``dataset.x_test`` and go through plot_img2 with
    detector titles enabled. For each attack family / variant / epsilon listed
    in ``attackMethods`` for the current dataset, the adversarial CSV is loaded
    with importAdv and plotted with plot_img.
    """
    num_img = 12
    plot_img2(dataset.x_test, 'benign\n \n ', first_flag=True, num_img=num_img)

    for ams in attackMethods:
        for am in attackMethods[ams]:
            for epsilon in attackMethods[ams][am]['epsilon'][dataset.name]:
                attack_tag = ams + '_' + am + '_' + str(epsilon)
                # NOTE(review): ``dgan`` is not defined in this function; it
                # resolves to the module-level variable left behind by the
                # detector-loading loop (the last detector created).
                adv_img_path = 'data//' + dgan.tar_model.name + '-adv-' + attack_tag + '.csv'
                testAdv = importAdv(adv_img_path)

                # Three-line label: family, variant, then epsilon. For
                # DeepFool and CW the third line is a blank placeholder.
                if ams in ('DeepFool', 'CW'):
                    third_line = ' '
                else:
                    third_line = str(epsilon)
                advname = ams + '\n' + am + '\n' + third_line
                plot_img(testAdv, advname, num_img=num_img)
                
def plot_box():
    """Save box plots of detector distances for clean vs. adversarial inputs.

    For every attack family / variant / epsilon listed in ``attackMethods`` for
    the current dataset, each detector in ``d_model_list`` loads the matching
    clean and adversarial CSV files and runs ``detect_adv``. The first two
    returned values (kept here as ``dist_raw`` and ``dist_adv``) are drawn as
    blue and red boxes respectively, one pair per detector, under the y-label
    'Reconstruction Error'.

    Figures are written to
    ``<tar_model.saveImgPath>/<family>_<variant>_<epsilon>_box.pdf``. The
    ``_box`` suffix keeps them from being overwritten by plot_ROC, which saves
    under the same stem without a suffix.
    """
    # Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'
    # and later dropped the old names; use whichever this install provides.
    style_name = 'seaborn-v0_8-whitegrid'
    if style_name not in plt.style.available:
        style_name = 'seaborn-whitegrid'
    plt.style.use(style_name)

    def _tint(boxes, colour):
        # Outline and fill each box in one colour, semi-transparent so the
        # clean and adversarial boxes remain visible where they overlap.
        for box in boxes:
            box.set(color=colour, linewidth=2, facecolor=colour, alpha=0.3)

    for ams in attackMethods:
        for am in attackMethods[ams]:
            epsilons = attackMethods[ams][am]['epsilon'][dataset.name]
            for epsilon in epsilons:
                attack_tag = ams + '_' + am + '_' + str(epsilon)
                dl_raw = []
                dl_adv = []
                for dgan in d_model_list:
                    data_prefix = 'data//' + dgan.tar_model.name
                    dgan.importAdv(
                        clean_img_path=data_prefix + '-clean-' + attack_tag + '.csv',
                        adv_img_path=data_prefix + '-adv-' + attack_tag + '.csv',
                    )
                    # Prefix for whatever detect_adv prints on the same line.
                    print(attack_tag, end=' ')
                    dist_raw, dist_adv, _, _ = dgan.detect_adv(plot_flag=False)
                    dl_raw.append(dist_raw)
                    dl_adv.append(dist_adv)

                # Both box sets are drawn at the same x positions.
                bp_raw = plt.boxplot(
                    dl_raw,
                    patch_artist=True,
                    showfliers=False,
                    labels=d_model_name,
                )
                _tint(bp_raw['boxes'], 'blue')
                bp_adv = plt.boxplot(
                    dl_adv,
                    patch_artist=True,
                    showfliers=False,
                    labels=d_model_name,
                )
                _tint(bp_adv['boxes'], 'red')

                plt.ylabel('Reconstruction Error', fontsize=22)
                plt.xticks(fontsize=22)
                plt.yticks(fontsize=22)
                plt.savefig(tar_model.saveImgPath + '/' + attack_tag + '_box.pdf')
                plt.clf()
                plt.close()
                
def plot_ROC():
    """Save one ROC-style figure per attack with a curve for every detector.

    For every attack family / variant / epsilon listed in ``attackMethods`` for
    the current dataset, each detector in ``d_model_list`` loads the matching
    clean and adversarial CSV files and runs ``detect_adv``. Its third and
    fourth return values (``fn`` and ``fp`` -- presumably false-negative and
    false-positive rates over a range of thresholds; confirm in the detector
    class) are plotted as ``1 - fn`` on the 'TPR' axis against ``fp`` on the
    'FPR' axis.

    Figures are written to
    ``<tar_model.saveImgPath>/<family>_<variant>_<epsilon>.pdf``.
    """
    # Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'
    # and later dropped the old names; use whichever this install provides.
    style_name = 'seaborn-v0_8-whitegrid'
    if style_name not in plt.style.available:
        style_name = 'seaborn-whitegrid'
    plt.style.use(style_name)

    for ams in attackMethods:
        for am in attackMethods[ams]:
            epsilons = attackMethods[ams][am]['epsilon'][dataset.name]
            for epsilon in epsilons:
                attack_tag = ams + '_' + am + '_' + str(epsilon)
                print(attack_tag, end=' ')
                for idx, dgan in enumerate(d_model_list):
                    data_prefix = 'data//' + dgan.tar_model.name
                    dgan.importAdv(
                        clean_img_path=data_prefix + '-clean-' + attack_tag + '.csv',
                        adv_img_path=data_prefix + '-adv-' + attack_tag + '.csv',
                    )
                    _, _, fn, fp = dgan.detect_adv(plot_flag=False)
                    fn = np.array(fn)
                    fp = np.array(fp)
                    tp = 1 - fn
                    # Integer marker codes 4, 5, ... give each detector its
                    # own marker shape.
                    plt.plot(fp, tp, marker=idx + 4, label=d_model_name[idx])

                plt.ylabel('TPR', fontsize=22)
                plt.xlabel('FPR', fontsize=22)
                plt.xticks(fontsize=22)
                plt.yticks(fontsize=22)
                plt.legend(fontsize=22)
                plt.tight_layout()
                plt.savefig(tar_model.saveImgPath + '/' + attack_tag + '.pdf')
                plt.clf()
                # Close the figure as well (as plot_box does) so it is not
                # left open after the function returns.
                plt.close()

if __name__ == '__main__':
    # Run the figure generators in their fixed order: box plots, ROC curves,
    # then the reconstruction image grids.
    for make_figures in (plot_box, plot_ROC, plot_all_image):
        make_figures()
